import gym
import numpy as np
from copy import deepcopy as cp
import torch
from mnp_attack import get_mnp_attacked_observation
import pickle
from train_victim_wm import MLP
import os

class AdversaryEnvWrapper(gym.Wrapper):
    def __init__(self, victim_env, args, run_mnp_attack=False):
        """Wrap ``victim_env`` so an adversary decides what the victim observes.

        Args:
            victim_env: Environment the victim acts in. It is passed to the
                ``gym.Wrapper`` constructor and also kept as
                ``self.victim_env`` (used for ``simulate_step``, ``render`` and
                for name-based lookups via ``str(victim_env)``).
            args: Configuration object. Attributes read here:
                ``adv_append_past``, ``adv_append_wm``, ``action_scale``,
                ``tolerance``, ``ill_rew_zero_w_tol``, ``victim_noise_sigma``,
                ``append_l2``, ``use_reduced_distance`` and
                ``illusionary_reward_weight``.
            run_mnp_attack: When true, ``get_adversarial_obs`` delegates to
                ``get_mnp_attacked_observation`` instead of using the
                adversary's action.
        """
        super().__init__(victim_env)

        # --- victim side -----------------------------------------------------
        self.victim_env = victim_env
        self.victim_agent_model = None  # must be assigned before step() is called
        self.other_is_deterministic = False
        self.lstm_states_victim = None
        self.victim_noise_sigma = args.victim_noise_sigma

        # --- adversary configuration -----------------------------------------
        self.run_mnp_attack = run_mnp_attack
        self.adv_append_past = args.adv_append_past
        self.adv_append_wm = args.adv_append_wm
        self.action_scale = args.action_scale
        # The action replaces the observation (instead of being added to it)
        # only when the world-model prediction is appended and the scale is 1.
        self.adv_additive = not (self.adv_append_wm and args.action_scale == 1)

        # --- reward / distance configuration ---------------------------------
        self.illusionary_reward_weight = args.illusionary_reward_weight
        self.tolerance = args.tolerance
        self.ill_rew_zero_w_tol = args.ill_rew_zero_w_tol
        self.distance_order = np.inf
        self.append_l2 = args.append_l2
        self.use_reduced_distance = args.use_reduced_distance

        # --- spaces ----------------------------------------------------------
        # The adversary emits one value in [-1, 1] per victim observation entry.
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=victim_env.observation_space.shape)

        # action scale not required to be larger than one if action is independent
        # NOTE(review): with adv_append_wm set but adv_append_past unset,
        # get_obs_adversary still appends a world-model prediction while the
        # space below stays the victim's own -- confirm that combination is unused.
        self.observation_space = (
            self.get_appended_observation_space()
            if self.adv_append_past
            else self.env.observation_space
        )

        # --- per-episode bookkeeping (filled in by reset()/step()) -----------
        self.last_done = None
        self.last_true_obs = None
        self.last_infos = None
        self.last_obs_seen = None
        self.last_victim_action = None

        # --- rendering -------------------------------------------------------
        # NOTE(review): this instance attribute presumably shadows the render()
        # method inherited from the wrapper base class -- confirm intended.
        self.render = False
        self.frames_true = []  # TODO this should technically be handled using another recording wrapper...
        self.frames_seen = []  # TODO do this in a later step

        # Per-dimension normalisation constants for the victim's observations.
        self.max_abs_obs = self.get_real_valued_max_abs_observation(str(self.victim_env))

        self.additional_metrics = [
            "illusionary_rewards",
            "illusionary_rewards_scaled",
            "victim_rewards",
            "notin_tol_n",
            "distance_inf_n",
            "distance_l2_n",
            "distance_l2_to_true_n",
            "distance_inf_to_true_n",
            "clipped",
        ]

        # Transition recording is off by default. To enable it, assign a dict
        # with list-valued keys "observations", "actions", "infos" and "dones"
        # (those are the keys process_step appends to).
        self.exps_collected = None

        # self.tolerance_statistics = []



    def get_appended_observation_space(self):
        """Build the adversary's observation space when past information is appended.

        The layout is ``[true obs, last obs shown to victim, last victim action]``
        and, when ``self.adv_append_wm`` is set, an additional block with the
        bounds of one more victim observation (the world-model prediction).

        Returns:
            gym.spaces.Box: Box whose bounds are the concatenation of the
            component bounds described above.
        """
        obs_low = self.env.observation_space.low
        obs_high = self.env.observation_space.high

        victim_action_space = self.env.action_space
        if isinstance(victim_action_space, gym.spaces.Box):
            action_low = victim_action_space.low
            action_high = victim_action_space.high
        else:
            # Non-Box action spaces are represented as one unbounded scalar.
            # ``np.NINF`` was removed in NumPy 2.0, so spell it ``-np.inf``.
            action_low = np.full(1, -np.inf)
            action_high = np.full(1, np.inf)

        lows = [obs_low, obs_low, action_low]
        highs = [obs_high, obs_high, action_high]
        if self.adv_append_wm:
            lows.append(obs_low)
            highs.append(obs_high)

        return gym.spaces.Box(low=np.concatenate(lows), high=np.concatenate(highs))

    def get_obs_adversary(self, obs_true, last_obs_seen, last_action):
        """Assemble the observation handed to the adversary.

        Args:
            obs_true: Current true observation of the victim environment.
            last_obs_seen: Observation shown to the victim in the previous
                step, or ``None`` at the start of an episode.
            last_action: Victim action of the previous step, or ``None``.

        Returns:
            1-D array: ``obs_true`` followed (depending on
            ``adv_append_past`` / ``adv_append_wm``) by the last shown
            observation, the last action and the world-model prediction.
            Missing values are filled with ``-100``.
        """
        episode_start = last_obs_seen is None
        if episode_start:
            last_obs_seen = np.ones(self.env.observation_space.shape)*-100

        if last_action is None:
            last_action = np.ones(self.env.action_space.shape)*-100

        # Scalar (0-d) actions need one axis so they can be concatenated.
        if len(last_action.shape) == 0:
            last_action = np.expand_dims(last_action, axis=0)

        parts = [obs_true]

        if self.adv_append_past:
            parts += [last_obs_seen, last_action]

        if self.adv_append_wm:
            if episode_start:
                # No previous step to simulate from: use the placeholder.
                prediction = np.ones(self.env.observation_space.shape)*-100
            else:
                # CartPole's simulator is fed the scalar action, others the array.
                sim_action = last_action[0] if "CartPole" in str(self.victim_env) else last_action
                prediction = self.simulate_step(initial_observation=last_obs_seen, action=sim_action)
            parts.append(prediction)

        return np.concatenate(parts)


    def simulate_step(self, initial_observation, action):
        """Return what ``self.victim_env.simulate_step`` predicts for one step.

        Both arguments are forwarded unchanged as keyword arguments.
        """
        simulator = self.victim_env
        predicted_obs = simulator.simulate_step(
            initial_observation=initial_observation,
            action=action,
        )
        return predicted_obs
        

    def reset(self):
        """Reset the wrapped environment and all per-episode adversary state.

        Returns:
            The adversary's first observation of the new episode (see
            ``get_obs_adversary``; past entries are filled with placeholders).

        Raises:
            ValueError: If the environment's initial observation lies outside
                the range returned by ``get_initial_state_range_victim``.
        """
        obs = self.env.reset()

        # Drop everything that belongs to the previous episode. step() only
        # clears the last shown observation / victim action when it sees
        # ``done``; clearing them here as well keeps a reset issued before
        # termination from leaking stale values into the next episode.
        self.lstm_states_victim = None
        self.last_obs_seen = None
        self.last_victim_action = None
        self.last_true_obs = obs

        low, high = self.get_initial_state_range_victim()
        if not ((low <= obs).all() and (high >= obs).all()):
            print("WARNING: Initial state is not in the range of the victim environment.")
            print("Initial state: ", obs)
            print("Range: ", (low, high))

            raise ValueError(
                f"Initial state {obs} is outside the victim's initial state range ({low}, {high})."
            )

        # Tells the victim policy that the next prediction starts an episode.
        self.last_done = True
        return self.get_obs_adversary(obs, last_obs_seen=None, last_action=None)

    def is_within_tolerance(self, obs_seen, next_obs_seen):
        """Element-wise check whether two observations agree up to ``self.tolerance``.

        Both inputs are divided by ``self.max_abs_obs`` first and, if
        ``self.use_reduced_distance`` is set, restricted via
        ``reduce_measurement``. The comparison is ``np.isclose`` with
        ``atol=self.tolerance`` (NumPy's default ``rtol`` still applies).

        Returns:
            Boolean array with one entry per compared dimension.
        """
        scaled_seen, scaled_next = (
            obs / self.max_abs_obs for obs in (obs_seen, next_obs_seen)
        )

        if self.use_reduced_distance:
            scaled_seen = self.reduce_measurement(scaled_seen)
            scaled_next = self.reduce_measurement(scaled_next)

        return np.isclose(scaled_seen, scaled_next, atol=self.tolerance)

    def get_illusionary_reward_truemodel(self, obs_seen):
        """Score how consistent ``obs_seen`` is with the simulated dynamics.

        At the start of an episode (``self.last_obs_seen is None``) the reward
        is the negative distance of ``obs_seen`` to the initial state range.
        Otherwise ``obs_seen`` is compared against the observation simulated
        from the previously shown observation and the victim's last action.

        Returns:
            tuple: ``(reward, reference_obs)`` where ``reference_obs`` is
            ``obs_seen`` itself at episode start and the simulated next
            observation otherwise. The reward is ``0`` or negative.
        """
        if self.last_obs_seen is None:
            return -1 * self.get_dist_to_initial_state_range(obs_seen), obs_seen

        predicted_obs = self.simulate_step(initial_observation=self.last_obs_seen, action=self.last_victim_action)

        # Inside the tolerance band the deviation is not penalised at all
        # (only when that behaviour is switched on).
        if self.ill_rew_zero_w_tol and self.is_within_tolerance(obs_seen, predicted_obs).all():
            return 0, predicted_obs

        penalty = -1 * self.get_distance(obs_seen, predicted_obs)
        if self.append_l2:
            penalty += -1 * self.get_distance(obs_seen, predicted_obs, distance_order=2)

        return penalty, predicted_obs

    def get_max_abs_reward_victim(self):
        """Return the constant used to normalise the victim's per-step reward.

        The environment is identified by substring match on
        ``str(self.victim_env)``; the first matching entry wins.

        Raises:
            NotImplementedError: If no known environment name matches.
        """
        env_name = str(self.victim_env)

        # (name fragments, normalisation constant), checked in this order.
        reward_bounds = (
            (("CartPole",), 1),
            (("Pendulum",), 16.2736044),
            (("HalfCheetah", "Hopper"), 3),
            (("Lunar",), 1),  # TODO
        )

        for fragments, bound in reward_bounds:
            if any(fragment in env_name for fragment in fragments):
                return bound

        raise NotImplementedError

    def get_initial_state_range_victim(self):
        """Return per-dimension ``(low, high)`` bounds for initial observations.

        The environment is identified by substring match on
        ``str(self.victim_env)``.

        Returns:
            tuple: Two arrays ``(low, high)`` of equal length.

        Raises:
            NotImplementedError: If no known environment name matches.
        """
        env_name = str(self.victim_env)

        if "CartPoleAndNoise" in env_name:
            # Four CartPole dimensions followed by four widely bounded ones.
            state_bound = 0.05*np.ones(4)
            extra_bound = 100*np.ones(4)
            low = np.concatenate((-0.05*np.ones(4), -100*np.ones(4)))
            high = np.concatenate((state_bound, extra_bound))
            return low, high

        if "CartPole-" in env_name:
            return -0.05*np.ones(4), 0.05*np.ones(4)

        if "Pendulum" in env_name:
            return np.array([-1] * 3), np.array([1] * 3)

        if "HalfCheetah" in env_name:
            bound = np.array([0.2] * 8 + [10] * 9)
            return -1 * bound, np.array([0.2] * 8 + [10] * 9)

        if "Hopper" in env_name:
            centre = np.array([1.25] + [0] * 10)
            return -0.01 + centre, 0.01 + centre

        if "Lunar" in env_name:
            # TODO: placeholder bounds
            return np.ones(8)*-100, np.ones(8)*100

        raise NotImplementedError

    def get_distance(self, obs_a, obs_b, distance_order=None):
        """Distance between two observations after scaling by ``self.max_abs_obs``.

        Args:
            obs_a: First observation.
            obs_b: Second observation.
            distance_order: Norm order forwarded to ``get_dist_from_diff``;
                ``None`` selects ``self.distance_order`` there.

        Returns:
            The norm of the scaled difference.
        """
        scaled_gap = obs_a / self.max_abs_obs - obs_b / self.max_abs_obs
        return self.get_dist_from_diff(scaled_gap, distance_order=distance_order)

    def get_dist_to_initial_state_range(self, obs):
        """Distance from ``obs`` to the box of valid initial states.

        Per dimension the violation is negative below the lower bound,
        positive above the upper bound and zero inside the range; the norm of
        that vector is taken by ``get_dist_from_diff``. Unlike
        ``get_distance``, no scaling by ``self.max_abs_obs`` is applied.
        """
        low, high = self.get_initial_state_range_victim()
        no_violation = np.zeros(low.shape)

        undershoot = np.minimum(obs - low, no_violation)
        overshoot = np.maximum(obs - high, no_violation)

        return self.get_dist_from_diff(undershoot + overshoot)

    def reduce_measurement(self, measurement):
        """Keep only the entries of ``measurement`` used by the reduced distance.

        For Hopper the first five entries are kept; for (Half)Cheetah the
        entries at indices 0-3 and 8-11.

        Raises:
            NotImplementedError: For any other environment.
        """
        env_name = str(self.victim_env)

        if "Hopper" in env_name:
            kept = slice(None, 5)
        elif "Cheetah" in env_name:
            kept = [0, 1, 2, 3, 8, 9, 10, 11]
        else:
            raise NotImplementedError

        return measurement[kept]


    def get_dist_from_diff(self, diff, distance_order=None):
        """Return the vector norm of a difference vector.

        Args:
            diff: Difference vector.
            distance_order: ``ord`` argument for ``np.linalg.norm``; ``None``
                falls back to ``self.distance_order``.

        If ``self.use_reduced_distance`` is set, ``diff`` is first restricted
        via ``reduce_measurement``.
        """
        order = self.distance_order if distance_order is None else distance_order
        vector = self.reduce_measurement(diff) if self.use_reduced_distance else diff
        return np.linalg.norm(vector, ord=order)


    def clip_obs(self, obs):
        """Clamp an observation to the range that is valid at this point.

        At the start of an episode (``self.last_obs_seen is None``) the bounds
        are the initial state range; afterwards they are plus/minus the
        per-dimension maximum absolute observation.

        Returns:
            tuple: ``(clamped_obs, clipped)`` where ``clipped`` is ``True``
            when at least one entry differs from the input.
        """
        if self.last_obs_seen is None:
            lower, upper = self.get_initial_state_range_victim()
        else:
            magnitude = self.get_real_valued_max_abs_observation(str(self.victim_env))
            lower, upper = -1*magnitude, magnitude

        # np.maximum / np.minimum allocate new arrays, so the caller's ``obs``
        # is left untouched and can serve as the reference for the comparison.
        bounded = np.minimum(np.maximum(obs, lower), upper)
        clipped = bool((bounded != obs).any())

        return bounded, clipped

    @staticmethod
    def get_real_valued_max_abs_observation(env_str):

        ratio_maximum_position_to_maximum_velocity = 10

        if "CartPoleAndNoise" in str(env_str):
            maximum_abs_observation = np.array([4.8, 4.8*ratio_maximum_position_to_maximum_velocity, 1.31314, 1.31314*ratio_maximum_position_to_maximum_velocity,
                                                4.8, 4.8*ratio_maximum_position_to_maximum_velocity, 1.31314, 1.31314*ratio_maximum_position_to_maximum_velocity])
        elif "CartPole" in str(env_str):
            maximum_abs_observation = np.array([4.8, 4.8*ratio_maximum_position_to_maximum_velocity, 1.31314, 1.31314*ratio_maximum_position_to_maximum_velocity])
        elif "HalfCheetah" in str(env_str):
            maximum_abs_observation = np.array([20.50983719, 53.32010321,  1.12667911,  0.94150945,  0.8486054 , 1.1707376 ,  1.2429014 ,  0.99495111,  5.42182194, 20.73862787, 23.97270022, 27.97178025, 29.07163353, 29.69758312, 30.47925965, 41.06976338, 25.21229142])
        elif "Hopper" in str(env_str):
            maximum_abs_observation = np.array([1.62698764,  1.88,  1.93542374,  0.66311227,  0.9, 5.15106928,  3.68316341,  10,  10, 10, 10])
        elif "Pendulum" in str(env_str):
            maximum_abs_observation = np.array([1,1,8])
        elif "Lunar" in str(env_str):
            maximum_abs_observation = np.ones(8) # TODO
        else:
            raise NotImplementedError()

        return maximum_abs_observation


    def get_adversary_reward(self, victim_reward, illusionary_reward, infos):
        """Combine the victim's reward and the illusionary reward.

        Args:
            victim_reward: Raw reward of the victim for this step.
            illusionary_reward: Consistency reward for the shown observation.
            infos: Unused; kept for interface compatibility.

        Returns:
            ``illusionary_reward`` alone when ``self.illusionary_reward_weight``
            equals ``"only"``; otherwise the negated, normalised victim reward
            plus the weighted illusionary reward.
        """
        # Normalise first so an unsupported environment fails in either mode.
        normalised_victim_reward = victim_reward/self.get_max_abs_reward_victim()

        weight = self.illusionary_reward_weight
        if weight == "only":
            return illusionary_reward

        return -1 * normalised_victim_reward + weight * illusionary_reward


    def write_metrics_to_infos(self, infos, victim_reward, illusionary_reward, obs_seen, obs_true, next_computed_obs, clipped):
        """Store the adversary's per-step metrics in ``infos`` and return it.

        Args:
            infos: Info dict of the current step; modified in place.
            victim_reward: Raw reward of the victim.
            illusionary_reward: Unweighted illusionary reward.
            obs_seen: Observation shown to the victim.
            obs_true: True observation the victim should have seen.
            next_computed_obs: Reference observation ``obs_seen`` is compared
                against for the tolerance and consistency distances.
            clipped: Whether the shown observation was clipped.

        Returns:
            The same ``infos`` dict with the metric entries added.
        """
        weight = self.illusionary_reward_weight
        scaled_illusionary = weight * illusionary_reward if weight != "only" else illusionary_reward
        outside_tolerance = not self.is_within_tolerance(obs_seen, next_computed_obs).all()

        infos.update({
            "illusionary_rewards": illusionary_reward,
            "illusionary_rewards_scaled": scaled_illusionary,
            "victim_rewards": victim_reward,
            "notin_tol_n": outside_tolerance,
            # distances to the reference (consistency of the shown observation)
            "distance_inf_n": self.get_distance(obs_seen, next_computed_obs, distance_order=np.inf),
            "distance_l2_n": self.get_distance(obs_seen, next_computed_obs, distance_order=2),
            # distances to the true observation (size of the perturbation)
            "distance_inf_to_true_n": self.get_distance(obs_seen, obs_true, distance_order=np.inf),
            "distance_l2_to_true_n": self.get_distance(obs_seen, obs_true, distance_order=2),
            "clipped": clipped,
        })

        return infos

    def process_step(self, action, obs, infos, done):
        """Record one transition if transition collection is enabled.

        Nothing happens while ``self.exps_collected`` is ``None``. Otherwise
        the four values are appended to the lists stored under the keys
        ``"actions"``, ``"observations"``, ``"infos"`` and ``"dones"``.
        """
        recorded = self.exps_collected
        if recorded is None:
            return

        recorded["actions"].append(action)
        recorded["observations"].append(obs)
        recorded["infos"].append(infos)
        recorded["dones"].append(done)

    def save_transitions_to_file(self):
        """Pickle the collected transitions to ``transitions_collected.pickle``.

        The file is written to the current working directory. Nothing happens
        when transition collection is disabled (``self.exps_collected`` is
        ``None``).
        """
        if self.exps_collected is None:
            return

        fname_exps = "transitions_collected.pickle"
        # The context manager flushes and closes the handle even if pickling
        # raises; previously the file object was never closed explicitly.
        with open(fname_exps, "wb") as out_file:
            pickle.dump(self.exps_collected, out_file)
        print(f"Saved transitions to file: {os.getcwd()}/{fname_exps}")


    def get_adversarial_obs(self, action, last_true_obs, last_obs_seen, last_victim_action):
        """Compute the observation that is shown to the victim.

        Modes, checked in this order:

        * ``self.run_mnp_attack``: delegate to ``get_mnp_attacked_observation``.
        * ``self.action_scale == -1``: scripted behaviour that ignores
          ``action``. The first observation of an episode is the sign-flipped
          true observation (for Hopper the first entry is mirrored around
          1.25); later observations are simulated from the previously shown
          observation and the victim's last action.
        * ``self.adv_additive``: true observation plus the scaled action.
        * otherwise: the scaled action alone.

        Raises:
            ValueError: In the scripted mode at episode start for an
                environment without a defined flip.
        """
        if self.run_mnp_attack:
            return get_mnp_attacked_observation(state=last_true_obs,
                                                model=self.victim_agent_model,
                                                max_abs_obs=self.max_abs_obs,
                                                budget=self.action_scale,
                                                sigma=self.victim_noise_sigma)

        if self.action_scale == -1:
            if last_obs_seen is not None:
                # Continue the fabricated trajectory with the simulator.
                return self.simulate_step(initial_observation=last_obs_seen, action=last_victim_action)

            env_name = str(self.victim_env)
            if any(tag in env_name for tag in ("Cart", "Pendulum", "HalfCheetah")):
                return -1 * last_true_obs
            if "Hopper" in env_name:
                # Flip every entry; the first one is mirrored around 1.25.
                mirrored = cp(last_true_obs)
                mirrored[0] -= 1.25
                mirrored *= -1
                mirrored[0] += 1.25
                return mirrored
            raise ValueError()

        if self.adv_additive:
            return last_true_obs + action * self.max_abs_obs * self.action_scale

        return action * self.max_abs_obs * self.action_scale

    def step(self, action):
        """Run one adversary step: fabricate the victim's observation, let the victim act.

        Sequence:
            1. Turn ``action`` into the observation shown to the victim.
            2. Compute the illusionary reward for that observation.
            3. Optionally record render frames and add Gaussian noise.
            4. Query ``self.victim_agent_model`` for the victim's action and
               step the wrapped environment with it.
            5. Compute the adversary's reward, write metrics into ``infos``
               and update the per-episode bookkeeping.

        Args:
            action: Adversary action, interpreted by ``get_adversarial_obs``.

        Returns:
            tuple: ``(obs_adv, adversary_reward, done, infos)`` where
            ``obs_adv`` is the adversary's next observation and ``infos`` is
            the environment's info dict extended with the adversary metrics.
        """

        # observation the victim will be shown instead of self.last_true_obs
        obs_seen = self.get_adversarial_obs(action, self.last_true_obs, self.last_obs_seen, self.last_victim_action)

        # consistency of obs_seen with the previously shown trajectory
        illusionary_reward, next_computed_obs = self.get_illusionary_reward_truemodel(obs_seen)

        # clip observation to fit initial distribution
        # (clipping is currently disabled, so ``clipped`` is always False)
        # obs_seen, clipped = self.clip_obs(obs_seen)
        clipped = False

        # todo move this to rendering wrapper?
        # self.render is used as a boolean flag here: when set, one frame of
        # the true state and one of the shown observation are recorded
        if self.render:

            obs_render_true = cp(self.last_true_obs)
            obs_render_obs = cp(obs_seen)
            self.frames_true.append(self.victim_env.render(mode="rgb_array", observation=obs_render_true))
            self.frames_seen.append(self.victim_env.render(mode="rgb_array", observation=obs_render_obs, is_obs=True))

        # potentially perturb victim observation with noise (robustify)
        # the noise is scaled per dimension by self.max_abs_obs; obs_seen
        # itself stays noise-free for rewards and metrics
        if self.victim_noise_sigma != 0:
            obs_used_by_victim = np.copy(obs_seen) + self.victim_noise_sigma*np.random.standard_normal(size=obs_seen.shape) * self.max_abs_obs
        else:
            obs_used_by_victim = obs_seen

        # get action of victim agent
        # (the returned state is carried over to the next call; episode_start
        # is the done flag of the previous step, or True right after reset())
        action_victim, self.lstm_states_victim = self.victim_agent_model.predict(
            observation=obs_used_by_victim,
            state=self.lstm_states_victim,
            episode_start=self.last_done,
            deterministic=self.other_is_deterministic)

        # step the environment with victim action
        next_true_obs, victim_reward, done, infos = self.env.step(action_victim)

        # record the transition (no-op unless self.exps_collected is set)
        self.process_step(action_victim, obs_used_by_victim, infos, done)

        # get the reward for the adversary agent
        adversary_reward = self.get_adversary_reward(victim_reward, illusionary_reward, infos)

        # update all the metrics that are written
        # (self.last_true_obs is still the true observation from before this step)
        infos = self.write_metrics_to_infos(infos, victim_reward, illusionary_reward, obs_seen, self.last_true_obs, next_computed_obs, clipped)

        # bookkeeping for the next step
        self.last_done = done
        self.last_infos = infos
        self.last_true_obs = next_true_obs
        self.last_obs_seen = obs_seen
        self.last_victim_action = action_victim

        if done:
            # TODO this should go to evaluation wrapper
            # black frames separate consecutive episodes in the recordings
            if self.render:
                self.frames_true.append(np.zeros(self.frames_true[0].shape, dtype=np.uint8))
                self.frames_seen.append(np.zeros(self.frames_seen[0].shape, dtype=np.uint8))
                print("episode terminated")

            # forget the shown trajectory so the next episode starts fresh;
            # the observation built below then uses the start-of-episode placeholders
            self.last_obs_seen = None
            self.last_victim_action = None

        obs_adv = self.get_obs_adversary(obs_true=next_true_obs, last_obs_seen=self.last_obs_seen, last_action=self.last_victim_action)

        return obs_adv, adversary_reward, done, infos
